{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Environment setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run from the repository root (presumably the parent of this notebook's folder) so the\n",
    "# `tot` imports and the relative `logs/` paths below resolve. Not idempotent: re-running\n",
    "# this cell moves up another level. The explicit `%cd` form also works in a multi-line\n",
    "# cell, whereas bare `cd` relies on automagic and only runs as a single-line cell.\n",
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "from tot.models import gpt\n",
    "from tot.prompts.crosswords import propose_prompt, value_prompt\n",
    "from tot.tasks.crosswords import MiniCrosswordsEnv\n",
    "\n",
    "# One environment instance shared by every cell below; each experiment re-targets it via env.reset(idx).\n",
    "env = MiniCrosswordsEnv()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prompt, response parsing, and greedy baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prompt_wrap(obs):\n",
    "    \"\"\"Return the crossword propose prompt filled in with the board observation `obs`.\"\"\"\n",
    "    filled = propose_prompt.format(input=obs)\n",
    "    return filled\n",
    "\n",
    "\n",
    "# Preview the prompt for env index 0 (print keeps the newlines readable).\n",
    "print(prompt_wrap(env.reset(0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import copy\n",
    "from tot.models import gpt\n",
    "\n",
    "def parse_line(input_str):\n",
    "    # regular expression pattern to match the input string format\n",
    "    pattern = r'^([hv][1-5])\\. ([a-zA-Z]{5,5}) \\((certain|high|medium|low)\\).*$'\n",
    "\n",
    "    # use regex to extract the parts of the input string\n",
    "    match = re.match(pattern, input_str)\n",
    "\n",
    "    if match:\n",
    "        # extract the matched groups\n",
    "        parts = [match.group(1), match.group(2), match.group(3)]\n",
    "        return parts\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "# Weight given to each self-reported confidence level when summing candidate scores.\n",
    "confidence_to_value = {'certain': 1, 'high': 0.5, 'medium': 0.2, 'low': 0.1}  # TODO: ad hoc\n",
    "\n",
    "\n",
    "def parse_response(response):\n",
    "    \"\"\"Extract scored candidate actions from one GPT completion.\n",
    "\n",
    "    Every line is run through `parse_line`; non-matching lines are skipped.\n",
    "    Each match becomes `(action, score)`, where action is `'<pos>. <word>'`\n",
    "    lower-cased (e.g. `'h1. apple'`) and score comes from `confidence_to_value`.\n",
    "\n",
    "    Returns:\n",
    "        A non-empty list of `(action, score)` tuples, or None if no line matched.\n",
    "    \"\"\"\n",
    "    scored = []\n",
    "    for raw_line in response.split('\\n'):\n",
    "        fields = parse_line(raw_line)\n",
    "        if fields is None:\n",
    "            continue\n",
    "        position, word, confidence = fields\n",
    "        action = position.lower() + '. ' + word.lower()\n",
    "        scored.append((action, confidence_to_value.get(confidence, 0)))\n",
    "    return scored or None\n",
    "\n",
    "\n",
    "def get_candidates_to_scores(env):\n",
    "    \"\"\"Map each proposed action for the current board to its summed confidence score.\n",
    "\n",
    "    Results are memoized in `env.cache`, keyed by `env.render()`. On a miss, 8 GPT-4\n",
    "    completions of the propose prompt are sampled and every parsed candidate's score\n",
    "    is summed across them. The cached dict itself is returned, so callers must not\n",
    "    mutate it.\n",
    "    \"\"\"\n",
    "    obs = env.render()\n",
    "    if obs in env.cache:\n",
    "        print('cache hit')\n",
    "        return env.cache[obs]\n",
    "\n",
    "    print('call gpt')\n",
    "    totals = {}\n",
    "    for completion in gpt(prompt_wrap(obs), model='gpt-4', n=8):\n",
    "        for action, score in parse_response(completion) or []:\n",
    "            totals[action] = totals.get(action, 0) + score\n",
    "    env.cache[obs] = totals\n",
    "    return totals\n",
    "\n",
    "def propose_score(env, idx):\n",
    "    \"\"\"Greedy baseline: fill puzzle `idx` one GPT-proposed word at a time.\n",
    "\n",
    "    Each round samples 5 GPT-4 completions of the propose prompt, sums every parsed\n",
    "    candidate's confidence score across them, and steps `env` with the highest-scored\n",
    "    candidate whose trial step on a deep copy leaves no cell with status 2 (i.e. breaks\n",
    "    no existing constraint). If every candidate conflicts, the highest-scored one is\n",
    "    played anyway. Stops when `env.step` reports done or nothing could be parsed.\n",
    "\n",
    "    Returns:\n",
    "        The list of `info` dicts returned by each real `env.step` call.\n",
    "    \"\"\"\n",
    "    obs = env.reset(idx)\n",
    "    done = False\n",
    "    infos = []\n",
    "    while not done:\n",
    "        responses = gpt(prompt_wrap(obs), model='gpt-4', n=5)\n",
    "        candidates_to_scores = {}\n",
    "        for response in responses:\n",
    "            for candidate, score in parse_response(response) or []:\n",
    "                candidates_to_scores[candidate] = candidates_to_scores.get(candidate, 0) + score\n",
    "        print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))\n",
    "        if not candidates_to_scores:\n",
    "            break  # no completion contained a parseable proposal\n",
    "        candidates = sorted(candidates_to_scores, key=candidates_to_scores.get, reverse=True)\n",
    "        for candidate in candidates:\n",
    "            trial_env = copy.deepcopy(env)\n",
    "            trial_env.step(candidate)\n",
    "            if not any(s == 2 for s in trial_env.status):\n",
    "                break\n",
    "        else:\n",
    "            # Every candidate conflicts. Without this, the loop variable would be left on the\n",
    "            # lowest-scored candidate and that one would be played; use the top-scored instead.\n",
    "            candidate = candidates[0]\n",
    "        print(candidate)\n",
    "        obs, _, done, info = env.step(candidate)\n",
    "        print(obs)\n",
    "        print(env.steps, info)\n",
    "        print('-------------------\\n\\n\\n')\n",
    "        infos.append(info)\n",
    "    return infos"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Depth-first search (DFS) over proposed words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dfs(env, actions, infos, time_limit, prune, max_per_state):\n",
    "    \"\"\"Depth-first search over GPT-proposed words, logging each visited state to `infos`.\n",
    "\n",
    "    Args:\n",
    "        env: crossword environment; after each candidate is tried its board, status and\n",
    "            step count are restored with `env.reset(env.idx, board=..., status=..., steps=...)`.\n",
    "        actions: stack of actions leading to the current state; pushed/popped in place.\n",
    "        infos: output log; one dict (`total_step`, `env_step`, `actions`, `info`, `count`)\n",
    "            is appended per visited state.\n",
    "        time_limit: stop expanding once `len(infos)` reaches this value.\n",
    "        prune: if True, do not descend from a state whose `env.prompt_status()` count\n",
    "            has at least one `'impossible'` entry.\n",
    "        max_per_state: expand at most this many constraint-respecting candidates per state.\n",
    "\n",
    "    Returns:\n",
    "        `(0, [], [])` if no candidates were proposed, otherwise None (the callers in\n",
    "        this notebook ignore the return value).\n",
    "    \"\"\"\n",
    "    # Once the budget is spent nothing more can be logged, so skip fetching candidates\n",
    "    # (a cache miss would otherwise cost a GPT call whose result is never used).\n",
    "    if len(infos) >= time_limit:\n",
    "        return\n",
    "\n",
    "    candidates_to_scores = get_candidates_to_scores(env)\n",
    "    if len(candidates_to_scores) == 0:\n",
    "        return 0, [], []\n",
    "    print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))\n",
    "\n",
    "    # Snapshot the current state so every candidate is tried from the same board.\n",
    "    board, status, steps = env.board.copy(), env.status.copy(), env.steps\n",
    "\n",
    "    cnt_per_state = 0\n",
    "    for action in sorted(candidates_to_scores, key=candidates_to_scores.get, reverse=True):\n",
    "        if len(infos) >= time_limit:\n",
    "            break  # infos only grows, so no later candidate could be logged either\n",
    "        _, _, _, info = env.step(action)\n",
    "        try:\n",
    "            # Skip states at the 10-step cap (presumably one step per clue, h1-h5/v1-v5)\n",
    "            # and states that violate an existing constraint (a cell with status 2).\n",
    "            if env.steps >= 10 or any(s == 2 for s in env.status):\n",
    "                continue\n",
    "            cnt_per_state += 1\n",
    "            if cnt_per_state > max_per_state:\n",
    "                break\n",
    "            count = env.prompt_status()\n",
    "            actions.append(action)\n",
    "\n",
    "            print(len(infos))\n",
    "            print(actions)\n",
    "            print(env.render_board())\n",
    "            print(info)\n",
    "            print(count)\n",
    "            if infos:\n",
    "                best = max(infos, key=lambda x: x['info']['r_word'])\n",
    "                print('best', best)\n",
    "            print('--------------')\n",
    "            print()\n",
    "\n",
    "            infos.append({'total_step': len(infos), 'env_step': env.steps, 'actions': actions.copy(), 'info': info, 'count': count})\n",
    "            if not prune or count['impossible'] < 1:  # only descend while the state still looks possible\n",
    "                dfs(env, actions, infos, time_limit, prune, max_per_state)\n",
    "            actions.pop()\n",
    "        finally:\n",
    "            # Undo the trial step on every path, including `break`, which would otherwise\n",
    "            # return with the env still one action ahead of the snapshot.\n",
    "            env.reset(env.idx, board=board.copy(), status=status.copy(), steps=steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DFS with pruning over env indices 0, 5, ..., 95: at most 100 logged states per index,\n",
    "# at most 3 candidates expanded per state.\n",
    "os.makedirs('logs/crosswords', exist_ok=True)  # open(..., 'w') below fails if the folder is missing\n",
    "infoss = []\n",
    "for i in range(0, 100, 5):\n",
    "    env.reset(i)\n",
    "    infos = []\n",
    "    actions = []\n",
    "    dfs(env, actions, infos, 100, prune=True, max_per_state=3)\n",
    "    infoss.append(infos)\n",
    "    # Rewrite the full log after every index so partial results survive an interruption.\n",
    "    with open('logs/crosswords/infoss_dfs_prune.json', 'w') as fout:\n",
    "        json.dump(infoss, fout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DFS without pruning over env indices 0, 5, ..., 95: at most 100 logged states per index,\n",
    "# at most 3 candidates expanded per state.\n",
    "os.makedirs('logs/crosswords', exist_ok=True)  # open(..., 'w') below fails if the folder is missing\n",
    "infoss = []\n",
    "for i in range(0, 100, 5):\n",
    "    env.reset(i)\n",
    "    infos = []\n",
    "    actions = []\n",
    "    dfs(env, actions, infos, 100, prune=False, max_per_state=3)\n",
    "    infoss.append(infos)\n",
    "    # Rewrite the full log after every index so partial results survive an interruption.\n",
    "    with open('logs/crosswords/infoss_dfs_no_prune.json', 'w') as fout:\n",
    "        json.dump(infoss, fout)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
